IBA Project

Dependencies

set.seed(84735)

library(rstanarm)
library(dplyr)
library(ggplot2)
library(stargazer)
library(bayesplot)
library(gridExtra)

# Load the fetal health data and recode the outcome: fh == 1 becomes the
# positive class "1", everything else "0" (factor with levels "0", "1").
dat <- read.csv("./fetal_health.csv")
dat$fh <- factor(as.integer(dat$fh == 1))

Data

Summary Statistics

table(dat$fh)
## 
##    0    1 
##  471 1655
# Summary-statistics table of all variables.
# NOTE(review): `type` is not defined anywhere in this file — presumably it
# is set in an earlier (unshown) setup chunk, e.g. type <- "text"; confirm.
stargazer(dat,
          title = "Summary Table",
          type = type,
          out.header = FALSE)
Summary Table
Statistic N Mean St. Dev. Min Max
bv 2,126 133.304 9.841 106 160
acc 2,126 0.003 0.004 0.000 0.019
fm 2,126 0.009 0.047 0.000 0.481
uc 2,126 0.004 0.003 0.000 0.015
ld 2,126 0.002 0.003 0.000 0.015
sd 2,126 0.00000 0.0001 0.000 0.001
pd 2,126 0.0002 0.001 0.000 0.005
astv 2,126 46.990 17.193 12 87
mvstv 2,126 1.333 0.883 0.200 7.000
ptwaltv 2,126 9.847 18.397 0 91
mvltv 2,126 8.188 5.628 0.000 50.700
hw 2,126 70.446 38.956 3 180
hmi 2,126 93.579 29.560 50 159
hma 2,126 164.025 17.944 122 238
hnp 2,126 4.068 2.949 0 18
hnz 2,126 0.324 0.706 0 10
hmo 2,126 137.452 16.381 60 187
hme 2,126 134.611 15.594 73 182
hmed 2,126 138.090 14.467 77 186
hv 2,126 18.808 28.978 0 269
ht 2,126 0.320 0.611 -1 1

Visualization of the data

# Boxplots of every predictor, split by outcome class.
long_dat <- tidyr::pivot_longer(
  dat,
  cols = -fh,
  names_to = "va",
  values_to = "val"
)

ggplot(long_dat) +
  geom_boxplot(aes(x = val, y = fh, fill = fh)) +
  facet_wrap(~va, ncol = 3) +
  theme(
    legend.position = "none",
    axis.title.x = element_blank()
  )

# Class-conditional density plots of every predictor.
caret::featurePlot(
  x = dat %>% select(!fh),
  y = dat$fh,
  plot = "density",
  scales = list(
    x = list(relation = "free"),
    y = list(relation = "free")
  ),
  adjust = 1.5,
  pch = "|",
  layout = c(3, 7)
)

Data Preprocessing

  1. We do a test-train split of 80-20.

  2. We scale and center all variables before fitting into the model.

# Scale and center every predictor (the factor outcome fh is left alone).
dat <- 
  dat %>%
  mutate(across(!fh, scale))

# Random 80-20 train/test split: TRUE (drawn with prob 0.2) marks a test row.
split <-
  sample(
    c(FALSE, TRUE),
    size = nrow(dat),
    replace = TRUE,  # was `T`; TRUE is safer since T can be reassigned
    prob = c(0.8, 0.2)
  )

test <- dat[split, ]
dat <- dat[!split, ]
table(test$fh)
## 
##   0   1 
## 101 335
table(dat$fh)
## 
##    0    1 
##  370 1320
stargazer(dat, title = "Scaled and Centered Train Data", type = type, out.header = FALSE)
Scaled and Centered Train Data
Statistic N Mean St. Dev. Min Max
bv 1,690 0.009 1.005 -2.775 2.713
acc 1,690 -0.012 0.991 -0.822 4.093
fm 1,690 -0.005 0.980 -0.203 10.018
uc 1,690 0.010 0.993 -1.482 3.609
ld 1,690 -0.00004 1.008 -0.638 4.429
sd 1,690 0.015 1.121 -0.057 17.395
pd 1,690 -0.011 0.978 -0.269 8.207
astv 1,690 0.001 0.992 -2.035 2.327
mvstv 1,690 0.001 1.002 -1.283 6.416
ptwaltv 1,690 0.002 0.994 -0.535 4.411
mvltv 1,690 0.001 0.988 -1.455 5.972
hw 1,690 -0.008 1.000 -1.731 2.812
hmi 1,690 0.011 1.002 -1.474 2.213
hma 1,690 0.002 1.009 -2.342 4.122
hnp 1,690 -0.010 0.991 -1.379 4.724
hnz 1,690 -0.004 0.987 -0.458 13.705
hmo 1,690 -0.00003 1.007 -4.728 3.025
hme 1,690 0.005 1.005 -3.951 2.911
hmed 1,690 0.003 0.998 -4.223 3.104
hv 1,690 -0.013 0.990 -0.649 8.116
ht 1,690 -0.004 1.003 -2.162 1.113
fh 1,690 0.781 0.414 0 1
stargazer(test, title = "Scaled and Centered Test Data", type = type, out.header = FALSE)
Scaled and Centered Test Data
Statistic N Mean St. Dev. Min Max
bv 436 -0.035 0.980 -2.775 2.611
acc 436 0.045 1.034 -0.822 3.576
fm 436 0.018 1.077 -0.203 10.104
uc 436 -0.040 1.029 -1.482 3.270
ld 436 0.0002 0.970 -0.638 4.091
sd 436 -0.057 0.000 -0.057 -0.057
pd 436 0.042 1.082 -0.269 8.207
astv 436 -0.004 1.031 -2.035 2.153
mvstv 436 -0.004 0.992 -1.283 5.624
ptwaltv 436 -0.008 1.023 -0.535 4.248
mvltv 436 -0.003 1.047 -1.455 7.553
hw 436 0.029 1.001 -1.629 2.042
hmi 436 -0.044 0.992 -1.474 2.179
hma 436 -0.009 0.964 -2.286 2.562
hnp 436 0.040 1.036 -1.379 4.046
hnz 436 0.016 1.050 -0.458 10.872
hmo 436 0.0001 0.975 -4.728 2.964
hme 436 -0.020 0.983 -3.502 3.039
hmed 436 -0.013 1.010 -3.601 3.312
hv 436 0.051 1.038 -0.649 8.634
ht 436 0.016 0.988 -2.162 1.113
fh 436 0.768 0.422 0 1

Logistic Regression with Gaussian Priors

The Model

\[ \text{Likelihood:}\quad fh_i \mid p_i \sim Bernoulli(p_i), \qquad \frac{p_i}{1-p_i} := \exp\left(\beta_0 + \sum_{j = 1}^{21}X_{ij}\beta_j\right) \\ \text{Priors:}\quad \beta_0 \sim Normal(0,10), \qquad \beta_j \sim Normal(0, 2.5) \quad \forall j \in \{1, 2, \ldots, 21\} \]

These are very weak priors, especially considering it is a logistic regression.

When combined with the standardization, a standard deviation of 2.5 implies that the difference in odds should be (almost always) less than \(e^5\) when there is one standard deviation rise in the predictor. This seems more than reasonable.

On the other hand, we have set a weaker prior on the intercept and allow it to vary more freely to allow p to take any value between 0 and 1 easily for cases when the predictors are all average i.e. 0.

Below is the specification of the model in rstanarm:

# Prior-only fit: with prior_PD = TRUE, rstanarm samples from the prior
# predictive distribution (the likelihood is ignored).
# NOTE(review): `num_iter` is not defined in this file — presumably set in
# an earlier (unshown) setup chunk; confirm.
nor_prior_mod <-
  stan_glm(
    fh ~ .,
    data = dat,
    family = binomial,
    prior_intercept = normal(0, 10),  # weak prior on the intercept
    prior = normal(0, 2.5),           # weakly-informative prior on slopes
    chains =  4,
    iter = num_iter*4,
    seed = 84735,
    prior_PD = TRUE,
    diagnostic_file = file.path(tempdir(), "df.csv")
  )


# Refit the same model conditioned on the data (posterior draws).
nor_pos_mod <- update(nor_prior_mod, prior_PD = FALSE)

Draws from the prior:

plot(nor_prior_mod)

pp_check(nor_prior_mod, plotfun = "bars", nreps = 1000)

The Result

Convergence of MCMC chains

mcmc_trace(nor_pos_mod, facet_args = list(ncol = 3))

mcmc_dens_overlay(nor_pos_mod, facet_args = list(ncol = 3))

Based on the above plots, we can say that the chains have converged.

We can confirm this by looking at the R-hat values. The R-hat statistic quantifies the convergence of multiple Markov chain Monte Carlo (MCMC) chains by comparing the pooled between-and-within-chain variance of a parameter's draws to the average within-chain variance. A value close to 1 indicates the chains agree with one another (good convergence), while values noticeably greater than 1 suggest the chains have not yet mixed.

summary(nor_pos_mod)[,"Rhat"] %>% round(3)
##   (Intercept)            bv           acc            fm            uc 
##         1.001         1.001         1.001         1.000         1.000 
##            ld            sd            pd          astv         mvstv 
##         1.000         1.000         1.001         1.000         1.000 
##       ptwaltv         mvltv            hw           hmi           hma 
##         1.000         1.000         1.001         1.001         1.001 
##           hnp           hnz           hmo           hme          hmed 
##         1.000         1.000         1.001         1.001         1.001 
##            hv            ht      mean_PPD log-posterior 
##         1.000         1.000         1.001         1.000

All values are very close to 1. So, the chains have converged.

Estimates

summary(nor_pos_mod)[1:22,c(1, 3:6)] %>% round(digits = 2)
##              mean   sd   10%   50%   90%
## (Intercept)  4.25 0.34  3.83  4.24  4.69
## bv           0.00 0.30 -0.39  0.00  0.38
## acc          3.59 0.47  2.99  3.58  4.20
## fm          -0.29 0.14 -0.47 -0.29 -0.11
## uc           0.76 0.12  0.61  0.76  0.92
## ld           0.07 0.20 -0.17  0.07  0.33
## sd          -0.07 0.09 -0.19 -0.06  0.03
## pd          -1.55 0.20 -1.81 -1.55 -1.30
## astv        -1.62 0.18 -1.85 -1.61 -1.39
## mvstv        0.34 0.23  0.05  0.34  0.64
## ptwaltv     -0.48 0.11 -0.63 -0.48 -0.34
## mvltv       -0.18 0.20 -0.44 -0.19  0.07
## hw           0.07 1.87 -2.35  0.05  2.53
## hmi         -0.63 1.43 -2.47 -0.64  1.23
## hma         -0.77 0.88 -1.92 -0.77  0.38
## hnp         -0.26 0.16 -0.47 -0.26 -0.05
## hnz          0.00 0.11 -0.13  0.00  0.13
## hmo          1.11 0.38  0.63  1.10  1.60
## hme         -1.35 0.45 -1.92 -1.34 -0.76
## hmed         0.18 0.53 -0.51  0.18  0.85
## hv          -1.19 0.22 -1.48 -1.19 -0.90
## ht          -0.24 0.16 -0.44 -0.23 -0.03

Checking the Posterior Predictive Distribution

pp_check(nor_pos_mod, plotfun = "bars", nreps = 1000)

The above plot shows that our model fits the data well overall.

Next, we will look at the fit with respect to each predictor.

  1. Discretize each predictor into 10 groups with approximately equal frequency.

  2. Then plot the observed log odds of fh

  3. Plot a boxplot of the predicted sample log odds.

# Thin the posterior draws, keeping every 100th row, then compare predicted
# log odds against observed log odds per discretized predictor.
# NOTE(review): assumes the fit retains num_iter*8 post-warmup draws in
# total (4 chains, iter = num_iter*4, default half warmup) — confirm.
pred_log_odds <- posterior_linpred(nor_pos_mod)[seq(1, num_iter * 8, 100), ]
grid.arrange(
  grobs = generateLogOddsPlots(dat, pred_log_odds),
  ncol = 4
  )

We also do the same for proportions.

# Same check on the probability scale: thinned posterior_epred() draws
# against the observed per-bin proportions.
pred_p <- posterior_epred(nor_pos_mod)[seq(1, num_iter*8, 100),]
grid.arrange(
  grobs = generateProportionPlots(dat, pred_p),
  ncol = 4
  )

Looking at the graphs, we see that the model fit is good.

However, it is not performing very well for certain predictors. For some, the relationship does not seem monotonic. So, next we will fit a new model that will allow quadratic terms for the following variables:

  • bv
  • mvstv
  • mvltv
  • hw
  • hmi
  • hmo
  • hmed
  • hv
  • hme
# Predictors whose relationship with the log odds looked non-monotonic;
# add a squared term for each of them.
cols <- c("bv", "mvstv", "mvltv", "hw", "hmi", "hmo", "hmed", "hv", "hme")

dat2 <- dat

# Loop variable renamed from `c` (which shadows base::c()) and
# paste(sep = "") replaced with the idiomatic paste0().
for (col in cols) {
  dat2[paste0(col, "_sq")] <- dat[[col]] ^ 2
}

# Class-conditional densities of the new squared predictors.
caret::featurePlot(
  x = select(dat2, ends_with("_sq")),
  y = as.factor(dat2$fh),
  plot = "density",
  scales = list(
    x = list(relation = "free"),
    y = list(relation = "free")
  ),
  adjust = 1.5,
  pch = "|",
  layout = c(3, 3)
)

Modified Models:

Gaussian

The Model

\[ \text{Likelihood:}\quad fh_i \mid p_i \sim Bernoulli(p_i), \qquad \frac{p_i}{1-p_i} := \exp\left(\beta_0 + \sum_{j = 1}^{30}X_{ij}\beta_j\right) \\ \text{Priors:}\quad \beta_0 \sim Normal(0,10), \qquad \beta_j \sim Normal(0, 2.5) \quad \forall j \in \{1, 2, \ldots, 30\} \] (21 original predictors plus 9 squared terms give 30 coefficients.) Below is the specification of the model in rstanarm:

# Prior-only fit of the quadratic model (dat2 adds the *_sq columns).
# NOTE(review): `num_iter` is not defined in this file — presumably set in
# an earlier (unshown) setup chunk; confirm.
normal_prior_mod <-
  stan_glm(
    fh ~ .,
    data = dat2,
    family = binomial,
    prior_intercept = normal(0, 10),  # weak prior on the intercept
    prior = normal(0, 2.5),           # weakly-informative prior on slopes
    chains =  4,
    iter = num_iter * 4,
    seed = 84735,
    prior_PD = TRUE,
    diagnostic_file = file.path(tempdir(), "df2.csv")
  )


# Posterior fit of the same model, then prior diagnostics.
normal_mod <- update(normal_prior_mod, prior_PD = FALSE)
plot(normal_prior_mod)

pp_check(normal_prior_mod, plotfun = "bars", nreps = 1000)

Convergence of MCMC chains

mcmc_trace(normal_mod, facet_args = list(ncol = 3))

mcmc_dens_overlay(normal_mod, facet_args = list(ncol = 3))

summary(normal_mod)[,"Rhat"] %>% round(3)
##   (Intercept)            bv           acc            fm            uc 
##         1.000         1.000         1.000         1.000         1.000 
##            ld            sd            pd          astv         mvstv 
##         1.000         1.000         1.000         1.000         1.000 
##       ptwaltv         mvltv            hw           hmi           hma 
##         1.000         1.001         1.001         1.001         1.001 
##           hnp           hnz           hmo           hme          hmed 
##         1.000         1.000         1.001         1.001         1.001 
##            hv            ht         bv_sq      mvstv_sq      mvltv_sq 
##         1.000         1.000         1.000         1.000         1.001 
##         hw_sq        hmi_sq        hmo_sq       hmed_sq         hv_sq 
##         1.000         1.000         1.001         1.001         1.001 
##        hme_sq      mean_PPD log-posterior 
##         1.001         1.000         1.000

The chains have converged.

Lasso

As there are a lot of variables, we have decided to not only use the normal model but use the Bayesian Lasso as well.

The Model

\[ \text{Likelihood:}\quad fh_i \mid p_i \sim Bernoulli(p_i), \qquad \frac{p_i}{1-p_i} := \exp\left(\beta_0 + \sum_{j = 1}^{30}X_{ij}\beta_j\right) \\ \text{Priors:}\quad \beta_0 \sim Normal(0,10), \qquad \beta_j \sim Laplace(0, 2.5) \quad \forall j \in \{1, 2, \ldots, 30\} \]

Below is the specification of the model in rstanarm:

# Bayesian Lasso variant: Laplace (double-exponential) priors on the slopes
# shrink small coefficients toward zero.
# NOTE(review): `num_iter` is not defined in this file — presumably set in
# an earlier (unshown) setup chunk; confirm.
lasso_prior_mod <-
  stan_glm(
    fh ~ .,
    data = dat2,
    family = binomial,
    prior_intercept = normal(0, 10),  # weak prior on the intercept
    prior = laplace(0, 2.5),          # shrinkage prior on slopes
    chains =  4,
    iter = num_iter * 4,
    seed = 84735,
    prior_PD = TRUE,
    diagnostic_file = file.path(tempdir(), "df3.csv")
  )

# Posterior fit of the same model, then prior diagnostics.
lasso_mod <- update(lasso_prior_mod, prior_PD = FALSE)
plot(lasso_prior_mod)

pp_check(lasso_prior_mod, plotfun = "bars", nreps = 1000)

Convergence of MCMC chains

mcmc_trace(lasso_mod, facet_args = list(ncol = 3))

mcmc_dens_overlay(lasso_mod, facet_args = list(ncol = 3))

summary(lasso_mod)[,"Rhat"] %>% round(3)
##   (Intercept)            bv           acc            fm            uc 
##         1.000         1.000         1.000         1.000         1.000 
##            ld            sd            pd          astv         mvstv 
##         1.000         1.000         1.000         1.000         1.000 
##       ptwaltv         mvltv            hw           hmi           hma 
##         1.000         1.000         1.000         1.000         1.000 
##           hnp           hnz           hmo           hme          hmed 
##         1.000         1.000         1.000         1.000         1.000 
##            hv            ht         bv_sq      mvstv_sq      mvltv_sq 
##         1.000         1.000         1.000         1.000         1.001 
##         hw_sq        hmi_sq        hmo_sq       hmed_sq         hv_sq 
##         1.000         1.000         1.000         1.000         1.000 
##        hme_sq      mean_PPD log-posterior 
##         1.000         1.000         1.004

The chains have converged.

Model Comparison

Estimates

# Coefficient summaries side by side: normal model (X.*) vs lasso (Y.*).
norm_est <- summary(normal_mod)[1:31, c(1, 3:6)] %>% round(digits = 2)
lasso_est <- summary(lasso_mod)[1:31, c(1, 3:6)] %>% round(digits = 2)
cbind(X = norm_est, Y = lasso_est)
##              mean   sd   10%   50%   90%  mean   sd   10%   50%   90%
## (Intercept)  6.05 0.55  5.35  6.04  6.78  6.08 0.56  5.38  6.06  6.82
## bv           0.06 0.39 -0.44  0.07  0.56  0.10 0.37 -0.38  0.10  0.57
## acc          4.47 0.53  3.80  4.46  5.15  4.53 0.55  3.84  4.52  5.25
## fm          -0.63 0.15 -0.82 -0.64 -0.45 -0.63 0.15 -0.82 -0.64 -0.45
## uc           0.80 0.13  0.63  0.80  0.98  0.80 0.14  0.63  0.80  0.97
## ld           0.72 0.29  0.34  0.71  1.10  0.70 0.29  0.33  0.70  1.08
## sd           0.10 0.13 -0.06  0.10  0.27  0.09 0.13 -0.06  0.09  0.26
## pd          -0.98 0.23 -1.29 -0.98 -0.69 -0.98 0.23 -1.28 -0.98 -0.69
## astv        -1.86 0.23 -2.16 -1.85 -1.57 -1.85 0.23 -2.15 -1.84 -1.56
## mvstv        1.12 0.41  0.60  1.10  1.65  1.10 0.41  0.58  1.10  1.63
## ptwaltv     -0.48 0.14 -0.67 -0.48 -0.30 -0.47 0.14 -0.65 -0.47 -0.29
## mvltv       -0.36 0.31 -0.77 -0.36  0.04 -0.33 0.30 -0.72 -0.33  0.04
## hw          -0.06 1.89 -2.49 -0.04  2.35 -0.10 1.60 -1.97 -0.11  1.74
## hmi         -0.12 1.45 -1.98 -0.12  1.74 -0.15 1.23 -1.59 -0.13  1.28
## hma         -0.32 0.91 -1.48 -0.32  0.85 -0.29 0.77 -1.19 -0.28  0.60
## hnp         -0.36 0.19 -0.61 -0.36 -0.12 -0.36 0.19 -0.60 -0.36 -0.12
## hnz         -0.02 0.12 -0.17 -0.02  0.13 -0.01 0.12 -0.16 -0.02  0.13
## hmo          0.93 0.54  0.24  0.93  1.63  0.89 0.55  0.21  0.88  1.60
## hme         -2.54 0.86 -3.64 -2.54 -1.43 -2.48 0.92 -3.68 -2.45 -1.33
## hmed         0.70 0.82 -0.34  0.69  1.75  0.63 0.79 -0.33  0.58  1.67
## hv          -2.38 0.53 -3.05 -2.38 -1.68 -2.33 0.52 -3.00 -2.33 -1.66
## ht          -0.15 0.18 -0.38 -0.15  0.08 -0.13 0.17 -0.35 -0.13  0.09
## bv_sq        0.86 0.17  0.65  0.86  1.09  0.85 0.16  0.64  0.84  1.06
## mvstv_sq    -0.15 0.08 -0.26 -0.15 -0.04 -0.15 0.08 -0.25 -0.15 -0.04
## mvltv_sq     0.07 0.12 -0.08  0.06  0.23  0.06 0.12 -0.09  0.05  0.22
## hw_sq        0.02 0.19 -0.22  0.02  0.26  0.02 0.18 -0.21  0.02  0.25
## hmi_sq      -0.26 0.19 -0.50 -0.26 -0.02 -0.25 0.19 -0.49 -0.25 -0.02
## hmo_sq      -0.01 0.22 -0.28 -0.02  0.27 -0.02 0.21 -0.29 -0.02  0.24
## hmed_sq     -0.35 0.36 -0.81 -0.35  0.11 -0.35 0.34 -0.80 -0.35  0.07
## hv_sq        0.17 0.15 -0.03  0.18  0.36  0.16 0.15 -0.03  0.16  0.34
## hme_sq      -1.06 0.27 -1.40 -1.06 -0.72 -1.04 0.27 -1.39 -1.03 -0.69
# Posterior densities side by side: normal model (left), lasso (right).
grid.arrange(
  mcmc_dens(normal_mod, facet_args = list(ncol = 1)),
  mcmc_dens(lasso_mod, facet_args = list(ncol = 1)),
  ncol = 2
)

Checking the Posterior Predictive Distribution

# Posterior predictive bar checks side by side: normal (left), lasso (right).
grid.arrange(
  pp_check(normal_mod, plotfun = "bars", nreps = 1000),
  pp_check(lasso_mod, plotfun = "bars", nreps = 1000),
  ncol = 2
)

Next, we will look at the fit with respect to log odds for each predictor.

norm_pred_log_odds <- posterior_linpred(normal_mod)[seq(1, num_iter*8, 10),]
lasso_pred_log_odds <- posterior_linpred(lasso_mod)[seq(1, num_iter*8, 10),]

# generateLogOddsPlots() returns a *list* of ggplots, and grid.arrange()
# has no `num_col` argument (it would have errored): wrap each list with
# arrangeGrob(grobs = ...) and place the two columns side by side via ncol.
grid.arrange(
  arrangeGrob(grobs = generateLogOddsPlots(dat, norm_pred_log_odds, num_col = 1), ncol = 1),
  arrangeGrob(grobs = generateLogOddsPlots(dat, lasso_pred_log_odds, num_col = 1), ncol = 1),
  ncol = 2
)

For proportions:

norm_pred_p <- posterior_epred(normal_mod)[seq(1, num_iter*8, 10),]
lasso_pred_p <- posterior_epred(lasso_mod)[seq(1, num_iter*8, 10),]

# generateProportionPlots() returns a *list* of ggplots, and grid.arrange()
# has no `num_col` argument (it would have errored): wrap each list with
# arrangeGrob(grobs = ...) and place the two columns side by side via ncol.
grid.arrange(
  arrangeGrob(grobs = generateProportionPlots(dat, norm_pred_p, num_col = 1), ncol = 1),
  arrangeGrob(grobs = generateProportionPlots(dat, lasso_pred_p, num_col = 1), ncol = 1),
  ncol = 2
)

The two models produce visibly similar fits in these diagnostics; neither clearly dominates the other.

Test using Bayes Factor

# Log marginal likelihood of each model via bridge sampling (this is why
# each stan_glm() call recorded a diagnostic_file).
library(bridgesampling)
vecLogML <-
  c(
    bridge_sampler(nor_pos_mod)$logml,
    bridge_sampler(normal_mod)$logml,
    bridge_sampler(lasso_mod)$logml
  )

Log Bayes Factor. Column / Row

# Pairwise log Bayes factors: mat[i, j] = logML(model j) - logML(model i),
# i.e. the log Bayes factor of the column model over the row model.
# outer() replaces the hand-written double loop.
mat <- outer(vecLogML, vecLogML, function(row_ml, col_ml) col_ml - row_ml)
mat
##           [,1]      [,2]      [,3]
## [1,]   0.00000 30.393609 24.608866
## [2,] -30.39361  0.000000 -5.784744
## [3,] -24.60887  5.784744  0.000000

We can see that model 2 (normal model with square predictors) is the best.

Results on Test Data

First Model

# Metric curves over all cutoffs; pick the cutoff maximizing balanced accuracy.
metrics <- accuracyByCutoff(test, nor_pos_mod)
cutoff <- metrics$cutoff[which.max(metrics$balanced_accuracy)]

metrics %>%
  tidyr::pivot_longer(cols = !cutoff,
                      names_to = "metric",
                      values_to = "value") %>%
  ggplot(aes(cutoff, value, color = metric)) +
  geom_line() +
  geom_vline(xintercept = cutoff) +
  ggtitle("Classification Metrics vs. Cutoff")

# ROC-style curve.
metrics %>%
  ggplot(aes(specificity, sensitivity)) +
  geom_line(color = "red") +
  ggtitle("Specificity vs. Sensitivity")

# Precision-recall curve.
metrics %>%
  ggplot(aes(precision, recall)) +
  geom_line(color = "red") +
  ggtitle("Precision vs. Recall")

# Classify each test row by its posterior-mean predicted probability.
# NOTE(review): `fh = mean(fh)` — if fh is still a factor here, mean()
# returns NA; presumably fh behaves as 0/1 numeric at this point — confirm.
test %>% 
  tidybayes::add_epred_draws(nor_pos_mod, ndraws = 100) %>% 
  group_by(.row) %>% 
  summarise(fh = mean(fh), pred = mean(.epred)) %>% 
  mutate(
    predicted = as.factor(if_else(pred > cutoff, 1, 0)),
    fh = as.factor(fh)
  ) %>% select(!.row) -> t

caret::confusionMatrix(reference = t$fh, data = t$predicted, mode = "sens_spec", positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0  89  13
##          1  12 322
##                                           
##                Accuracy : 0.9427          
##                  95% CI : (0.9165, 0.9626)
##     No Information Rate : 0.7683          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8395          
##                                           
##  Mcnemar's Test P-Value : 1               
##                                           
##             Sensitivity : 0.9612          
##             Specificity : 0.8812          
##          Pos Pred Value : 0.9641          
##          Neg Pred Value : 0.8725          
##              Prevalence : 0.7683          
##          Detection Rate : 0.7385          
##    Detection Prevalence : 0.7661          
##       Balanced Accuracy : 0.9212          
##                                           
##        'Positive' Class : 1               
## 

Second Model

# Add the same squared predictors to the test set as were added to dat2.
cols <- c("bv", "mvstv", "mvltv", "hw", "hmi", "hmo", "hmed", "hv", "hme")

test2 <- test

# Loop variable renamed from `c` (which shadows base::c()) and
# paste(sep = "") replaced with the idiomatic paste0().
for (col in cols) {
  test2[paste0(col, "_sq")] <- test[[col]] ^ 2
}

# Same evaluation as the first model, on the quadratic (normal-prior) model.
metrics <- accuracyByCutoff(test2, normal_mod)
cutoff <- metrics$cutoff[which.max(metrics$balanced_accuracy)]

metrics %>%
  tidyr::pivot_longer(cols = !cutoff,
                      names_to = "metric",
                      values_to = "value") %>%
  ggplot(aes(cutoff, value, color = metric)) +
  geom_line() +
  geom_vline(xintercept = cutoff) +
  ggtitle("Classification Metrics vs. Cutoff")

# ROC-style curve.
metrics %>%
  ggplot(aes(specificity, sensitivity)) +
  geom_line(color = "red") +
  ggtitle("Specificity vs. Sensitivity")

# Precision-recall curve.
metrics %>%
  ggplot(aes(precision, recall)) +
  geom_line(color = "red") +
  ggtitle("Precision vs. Recall")

# Classify each test row by its posterior-mean predicted probability.
# NOTE(review): `fh = mean(fh)` — if fh is still a factor here, mean()
# returns NA; presumably fh behaves as 0/1 numeric at this point — confirm.
test2 %>% 
  tidybayes::add_epred_draws(normal_mod, ndraws = 100) %>% 
  group_by(.row) %>% 
  summarise(fh = mean(fh), pred = mean(.epred)) %>% 
  mutate(
    predicted = as.factor(if_else(pred > cutoff, 1, 0)),
    fh = as.factor(fh)
  ) %>% select(!.row) -> t2

caret::confusionMatrix(t2$predicted, reference = t2$fh, positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0  98  34
##          1   3 301
##                                           
##                Accuracy : 0.9151          
##                  95% CI : (0.8849, 0.9395)
##     No Information Rate : 0.7683          
##     P-Value [Acc > NIR] : 6.807e-16       
##                                           
##                   Kappa : 0.7847          
##                                           
##  Mcnemar's Test P-Value : 8.140e-07       
##                                           
##             Sensitivity : 0.8985          
##             Specificity : 0.9703          
##          Pos Pred Value : 0.9901          
##          Neg Pred Value : 0.7424          
##              Prevalence : 0.7683          
##          Detection Rate : 0.6904          
##    Detection Prevalence : 0.6972          
##       Balanced Accuracy : 0.9344          
##                                           
##        'Positive' Class : 1               
## 

LASSO Model

# Same evaluation as above, on the lasso model.
metrics <- accuracyByCutoff(test2, lasso_mod)

cutoff <- metrics$cutoff[which.max(metrics$balanced_accuracy)]

metrics %>%
  tidyr::pivot_longer(cols = !cutoff,
                      names_to = "metric",
                      values_to = "value") %>%
  ggplot(aes(cutoff, value, color = metric)) +
  geom_line() +
  geom_vline(xintercept = cutoff) +
  ggtitle("Classification Metrics vs. Cutoff")

# ROC-style curve.
metrics %>%
  ggplot(aes(specificity, sensitivity)) +
  geom_line(color = "red") +
  ggtitle("Specificity vs. Sensitivity")

# Precision-recall curve.
metrics %>%
  ggplot(aes(precision, recall)) +
  geom_line(color = "red") +
  ggtitle("Precision vs. Recall")

# Classify each test row by its posterior-mean predicted probability.
# NOTE(review): `fh = mean(fh)` — if fh is still a factor here, mean()
# returns NA; presumably fh behaves as 0/1 numeric at this point — confirm.
test2 %>% 
  tidybayes::add_epred_draws(lasso_mod, ndraws = 100) %>% 
  group_by(.row) %>% 
  summarise(fh = mean(fh), pred = mean(.epred)) %>% 
  mutate(
    predicted = as.factor(if_else(pred > cutoff, 1, 0)),
    fh = as.factor(fh)
  ) %>% select(!.row) -> t2

caret::confusionMatrix(t2$predicted, reference = t2$fh, positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0  97  31
##          1   4 304
##                                           
##                Accuracy : 0.9197          
##                  95% CI : (0.8901, 0.9435)
##     No Information Rate : 0.7683          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.7938          
##                                           
##  Mcnemar's Test P-Value : 1.109e-05       
##                                           
##             Sensitivity : 0.9075          
##             Specificity : 0.9604          
##          Pos Pred Value : 0.9870          
##          Neg Pred Value : 0.7578          
##              Prevalence : 0.7683          
##          Detection Rate : 0.6972          
##    Detection Prevalence : 0.7064          
##       Balanced Accuracy : 0.9339          
##                                           
##        'Positive' Class : 1               
## 

Appendix: Code for Helper Functions

generateProportionPlots <-
  function(dat,
           nor_fh_pred,
           num_col = 2) {
    # For each predictor: discretize into 10 roughly equal-frequency bins,
    # then compare the observed proportion of fh == 1 per bin (red stars)
    # against boxplots of the posterior predicted proportions.
    #
    # dat:         data frame with outcome column `fh` plus predictors.
    # nor_fh_pred: draws x observations matrix from posterior_epred().
    # num_col:     kept for interface compatibility; unused here (callers
    #              lay out the returned plots with grid.arrange()).
    # Returns a named list of ggplot objects, one per predictor.
    new_dat <-
      data.frame(lapply(select(dat, !fh), function(x) as.factor(
        arules::discretize(
          x,
          method = "frequency",  # was `methods =`, a misspelled argument
          breaks = 10,
          labels = FALSE,
          infinity = TRUE,
          ordered_result = TRUE
        ))))
    
    # One long row per (observation, posterior draw); the t(nor_fh_pred)
    # columns get default names X1, X2, ... which pivot_longer collects.
    useful <-
      data.frame(new_dat,
                 fh = dat$fh,
                 t(nor_fh_pred)) %>%
      tidyr::pivot_longer(cols = starts_with("X"),
                          values_to = "proportion")
    
    plots <- list()
    for (col in names(new_dat)) {
      # .data[[col]] replaces the deprecated group_by_()/aes_string().
      # mean() of a factor is NA, so coerce fh to 0/1 first (this is also
      # a no-op-safe round trip if fh is already numeric 0/1).
      p <- useful %>%
        group_by(.data[[col]]) %>%
        mutate(m = mean(as.numeric(as.character(fh)))) %>%
        ggplot(aes(.data[[col]])) +
        geom_boxplot(aes(y = proportion), outlier.shape = NA) +
        geom_point(aes(y = m), color = "firebrick", shape = 8)
      
      plots[[col]] <- p
    }
    return(plots)
  }

generateLogOddsPlots <-
  function(dat,
           nor_fh_pred,
           num_col = 2) {
    # For each predictor: discretize into 10 roughly equal-frequency bins,
    # then compare the observed log odds of fh == 1 per bin (red stars)
    # against boxplots of the posterior predicted log odds.
    #
    # dat:         data frame with outcome column `fh` plus predictors.
    # nor_fh_pred: draws x observations matrix from posterior_linpred().
    # num_col:     kept for interface compatibility; unused here (callers
    #              lay out the returned plots with grid.arrange()).
    # Returns a named list of ggplot objects, one per predictor.
    new_dat <-
      data.frame(lapply(select(dat, !fh), function(x) as.factor(
        arules::discretize(
          x,
          method = "frequency",  # was `methods =`, a misspelled argument
          breaks = 10,
          labels = FALSE,
          infinity = TRUE,
          ordered_result = TRUE
        ))))
    
    # One long row per (observation, posterior draw); the t(nor_fh_pred)
    # columns get default names X1, X2, ... which pivot_longer collects.
    useful <-
      data.frame(new_dat,
                 fh = dat$fh,
                 t(nor_fh_pred)) %>%
      tidyr::pivot_longer(cols = starts_with("X"),
                          values_to = "log_odds")
    
    plots <- list()
    for (col in names(new_dat)) {
      # .data[[col]] replaces the deprecated group_by_()/aes_string().
      # mean() of a factor is NA, so coerce fh to 0/1 first; qlogis(p) is
      # the stats-library log(p / (1 - p)).
      p <- useful %>%
        group_by(.data[[col]]) %>%
        mutate(m = qlogis(mean(as.numeric(as.character(fh))))) %>%
        ggplot(aes(.data[[col]])) +
        geom_boxplot(aes(y = log_odds), outlier.shape = NA) +
        geom_point(aes(y = m), color = "firebrick", shape = 8)
      
      plots[[col]] <- p
    }
    return(plots)
  }

accuracyByCutoff <- function(dat, model) {
  # Sweep 101 classification cutoffs over [0, 100/101] and compute test
  # metrics at each one, using the mean of 100 posterior epred draws per
  # row as the predicted probability.
  #
  # dat:   data frame with outcome `fh` and the model's predictors.
  # model: fitted rstanarm model accepted by tidybayes::add_epred_draws().
  # Returns a tibble with columns: cutoff, balanced_accuracy, sensitivity,
  # specificity, precision, recall, f1.
  n <- 100
  cutoff <- c(0:n) / (n + 1)

  # Preallocate one slot per cutoff. `balanced_accuracy` was previously
  # (misleadingly) named `accuracy` even though it stored Balanced Accuracy.
  balanced_accuracy <- rep(0, n + 1)
  sensitivity <- rep(0, n + 1)
  specificity <- rep(0, n + 1)
  precision <- rep(0, n + 1)
  recall <- rep(0, n + 1)
  f1 <- rep(0, n + 1)

  # Posterior-mean prediction per observation, computed once and reused for
  # every cutoff. mean() of a factor is NA, so coerce fh to 0/1 first (this
  # is also a no-op-safe round trip if fh is already numeric 0/1).
  preds <- dat %>%
    tidybayes::add_epred_draws(model, ndraws = 100) %>%
    group_by(.row) %>%
    summarise(fh = mean(as.numeric(as.character(fh))), pred = mean(.epred))

  # seq_along-indexed loop replaces the manual `i = i + 1` counter.
  for (i in seq_along(cutoff)) {
    cut <- cutoff[i]
    classified <- preds %>%
      mutate(predicted = as.factor(if_else(pred > cut, 1, 0)),
             fh = as.factor(fh)) %>%
      select(!.row)
    cm <- caret::confusionMatrix(classified$predicted,
                                 reference = classified$fh,
                                 positive = "1")
    balanced_accuracy[i] <- cm[["byClass"]][["Balanced Accuracy"]]
    sensitivity[i] <- cm[["byClass"]][["Sensitivity"]]
    specificity[i] <- cm[["byClass"]][["Specificity"]]
    precision[i] <- cm[["byClass"]][["Precision"]]
    recall[i] <- cm[["byClass"]][["Recall"]]
    f1[i] <- cm[["byClass"]][["F1"]]
  }

  tibble(cutoff, balanced_accuracy, sensitivity, specificity,
         precision, recall, f1)
}